{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Antibody Generation with AntibodyGPT\n",
        "This notebook is a companion to Chapter 7 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "The code in this notebook generates antibody sequences using the [AntibodyGPT](https://huggingface.co/AntibodyGeneration/fine-tuned-progen2-small) model. It requires hardware acceleration (a GPU runtime).  \n",
        "More details about the code can be found in the related chapter of the book."
      ],
      "metadata": {
        "id": "_CK9070kfjWE"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Downgrade the Hugging Face Transformers library to ensure compatibility with AntibodyGPT's `ProGenForCausalLM` class. This class inherits from `PreTrainedModel`, which, starting from Transformers release 4.50, no longer inherits from `GenerationMixin`, so the `generate` method would no longer be available."
      ],
      "metadata": {
        "id": "gLmyzX4VMjgH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install transformers==4.49.0"
      ],
      "metadata": {
        "id": "YGfnxHmKIlYo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Clone the official repository and change into its directory."
      ],
      "metadata": {
        "id": "MasdCI-YhW0y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/joethequant/docker_protein_generator.git\n",
        "%cd ./docker_protein_generator/"
      ],
      "metadata": {
        "id": "7t9pkqdStJai"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download one of the pretrained models and its associated tokenizer from the Hugging Face Hub. Note that the model class to use is the custom `ProGenForCausalLM` available in the cloned `docker_protein_generator` repo, rather than one of the Transformers Auto classes."
      ],
      "metadata": {
        "id": "YqvLKJtLhbAm"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yhPz5wH1p_zq"
      },
      "outputs": [],
      "source": [
        "# Custom model class defined in the cloned docker_protein_generator repo\n",
        "from models.progen.modeling_progen import ProGenForCausalLM\n",
        "import torch\n",
        "from tokenizers import Tokenizer\n",
        "\n",
        "# Hugging Face Hub id of the fine-tuned checkpoint\n",
        "model_path = 'AntibodyGeneration/fine-tuned-progen2-small'\n",
        "\n",
        "# Download the model weights and the tokenizer from the Hub\n",
        "model = ProGenForCausalLM.from_pretrained(model_path)\n",
        "tokenizer = Tokenizer.from_pretrained(model_path)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from pathlib import Path\n",
        "\n",
        "models_path = Path(\"antibodygen\")\n",
        "model.save_pretrained(models_path)"
      ],
      "metadata": {
        "id": "lYCVjDA8eKnc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The model saved to disk takes 588.6 MB (617 MB in memory after download)."
      ],
      "metadata": {
        "id": "q9dv5l-hfIYa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define a target antigen sequence and the number of antibody sequences you want to generate for it."
      ],
      "metadata": {
        "id": "Nw446mYAiBDP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "target_sequence = 'MQIPQAPWPVVWAVLQLGWRPGWFLDSPDRPWNPPTFSPALLVVTEGDNATFTCSFSNTSESFVLNWYRMSPSNQTDKLAAFPEDRSQPGQDCRFRVTQLPNGRDFHMSVVRARRNDSGTYLCGAISLAPKAQIKESLRAELRVTERRAEVPTAHPSPSPRPAGQFQTLVVGVVGGLLGSLVLLVWVLAVICSRAARGTIGARRTGQPLKEDPSAVPVFSVDYGELDFQWREKTPEPPVPCVPEQTEYATIVFPSGMGTSSPARRGSADGPRSAQPLRPEDGHCSWPL'\n",
        "number_of_sequences = 2"
      ],
      "metadata": {
        "id": "lqBaTTrruwsl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Tokenize the prompt sequence, convert the token ids to a PyTorch tensor, and move the tensor to the GPU (or to the CPU if no GPU is available)."
      ],
      "metadata": {
        "id": "eEtma_bxiVwA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
        "tokenized_sequence = tokenizer.encode(target_sequence)\n",
        "input_tensor = torch.tensor([tokenized_sequence.ids]).to(device)"
      ],
      "metadata": {
        "id": "XQZMB5ZaiWIA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Move the model to the same device as the input tensor (the GPU, when available)."
      ],
      "metadata": {
        "id": "y_nzxqiUiu_q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = model.to(device)"
      ],
      "metadata": {
        "id": "dYJtRW1tN4Pu"
      },
      "execution_count": null,
      "outputs": []
    },
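    {
      "cell_type": "markdown",
      "source": [
        "Start the sequence generation. Sampling is enabled (`do_sample=True`) with nucleus sampling (`top_p=0.9`) and a temperature of 0.8, and `max_length=1024` caps the total length of each sequence, prompt included."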
      ],
      "metadata": {
        "id": "6C73SQG9i_Te"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with torch.no_grad():\n",
        "    output = model.generate(input_tensor, max_length=1024,\n",
        "                            pad_token_id=tokenizer.encode('<|pad|>').ids[0],\n",
        "\t                          do_sample=True, top_p=0.9, temperature=0.8,\n",
        "\t                          num_return_sequences=number_of_sequences)\n"
      ],
      "metadata": {
        "id": "8M336VkyyV9j"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Decode the generated sequences and display them."
      ],
      "metadata": {
        "id": "Ev2BO1GBjNzx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Convert the output tensor into a list of token-id lists, one per generated sequence\n",
        "token_ids = output.detach().cpu().tolist()\n",
        "sequences = tokenizer.decode_batch(token_ids)\n",
        "# Remove the '2' characters, which mark the sequence terminus in ProGen2 outputs\n",
        "sequences = [x.replace('2', '') for x in sequences]"
      ],
      "metadata": {
        "id": "D2I8U1rPK-AZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "sequences"
      ],
      "metadata": {
        "id": "Qz_VvGviLM3q"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}